#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors

{#- Template-wide name fragments: `mdesc` is the prefix of the lookup op name
    called at the end of `invoke`, and `sdesc` is the suffix of the
    `lookup_args` module imported below. Both switch on the `ssd` flag. #}
{%- set mdesc = "ssd" if ssd else "split" %}
{%- set sdesc = "_ssd" if ssd else "" %}

import torch
{%- if is_experimental_optimizer %}
import warnings
{%- endif %}
from .lookup_args{{ sdesc }} import *

{%- if is_fbcode %}

from fbgemm_gpu.utils.loader import load_torch_module, load_torch_module_bc
# Provide compatibility to downstream packages for eventual migration to the split training / inference packages
try:
    # First attempt: load the training-specific targets through the loader helpers.
    load_torch_module(
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_training_gpu",
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cuda_training",
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_hip_training",
    )
    load_torch_module_bc(
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_training_cpu",
        "//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu_training",
    )
except Exception:
    # Fallback: if either helper call above raises, load the `embedding_ops`
    # and `embedding_ops_cpu` libraries directly instead.
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops")
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu")

# The remaining operator libraries are loaded unconditionally (outside the try/except).
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:embedding_inplace_update")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:split_table_batched_embeddings")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")

{%- endif %}

{# This macro generates a code blob that packs the Tensor fields of `arg` into a
    TensorList named `<arg>_list`, as the number of arguments for some optimizers
    exceeds 64.
      - when `use_cpu` is false: [dev, uvm, placements, offsets]
      - when `use_cpu` is true:  [host, placements, offsets]
    The generated code expects a local `use_cpu` to be in scope at the call site.
    The device selection is parenthesized on purpose: a chained conditional
    expression is right-associative, so without the parentheses the
    `is not None` guard would only apply to the CPU list and a `None` argument
    would raise AttributeError on the non-CPU path instead of producing `None`. #}
{%- macro pack_tensors(arg) %}
    {{ arg }}_list = (
        [
            {{ arg }}.dev,
            {{ arg }}.uvm,
            {{ arg }}.placements,
            {{ arg }}.offsets,
        ] if not use_cpu else [
            {{ arg }}.host,
            {{ arg }}.placements,
            {{ arg }}.offsets,
        ]
    ) if {{ arg }} is not None else None
{%- endmacro %}

{# This macro generates a code blob that flattens one optional optimizer tensor into
    the shared optional TensorList `optim_tensor`.
    Every optional tensor contributes exactly 5 consecutive entries, in the unified order
    (host, dev, uvm, placements, offsets); when the argument is None it contributes 5 `None` entries.
    A per-device layout (i.e., 3 items for cpu and 4 for cuda) would make unpacking in autograd
    harder, so the unified 5-item layout is used for readability and programmability.
 #}
{%- macro pack_optim_optional_tensor(arg) %}
    {%- set unified_fields = ["host", "dev", "uvm", "placements", "offsets"] %}
    # using .extend fails torch script
    if {{ arg }} is None:
        {%- for _field in unified_fields %}
        optim_tensor.append(None)
        {%- endfor %}
    else:
        {%- for field in unified_fields %}
        optim_tensor.append({{ arg }}.{{ field }})
        {%- endfor %}
{%- endmacro %}

{# This macro generates a code blob to pack auxiliary arguments of the same type into a list.
    Every argument of type `t` goes into `aux_{t}`, following the order given by the dict `aux_args`.
    That dict is maintained in generate_backward_split.py and is used by all templates that pack/unpack
    these arguments. The generated code reads each value from a local `dict_<arg_type>` and asserts
    that the expected key is present before appending it.
 #}
{%- macro pack_to_list(arg_type) %}
    {%- if arg_type == "aux_tensor" %}
    {%- set elem_type = "Optional[torch.Tensor]" %}
    {%- else %}
    {%- set elem_type = arg_type.split("_")[1] %}
    {%- endif %}
    {{ arg_type }}: List[{{ elem_type }}] = []
    {%- for key in aux_args[arg_type] %}
    assert "{{ key }}" in dict_{{ arg_type }}, (
        "{{ key }} must be in dict_{{ arg_type }}. "
        "Please check the frontend and backend version. "
    )
    {{ arg_type }}.append(dict_{{ arg_type }}["{{ key }}"])
    {%- endfor %}
{%- endmacro %}


{%- if is_prototype_optimizer %}
# Decorate the prototype optimizers which may be deprecated in the future with jit.ignore to avoid
# possible errors from torch.jit.script. 
# Note that backends can be removed but the lookup invoker is still needed for backward compatibility
@torch.jit.ignore
{%- endif %}
def invoke(
    common_args: CommonArgs,
    optimizer_args: OptimizerArgs,
    {%- if "momentum1" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    momentum1: Momentum,
    {%- endif %}
    {%- if "momentum2" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    momentum2: Momentum,
    {%- endif %}
    {%- if "prev_iter" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    prev_iter: Momentum,
    {%- endif %}
    {%- if "row_counter" in args_pt2.unified_pt2.split_unpacked_arg_names and "row_counter" not in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
    row_counter: Momentum,
    {%- endif %}
    {%- if "iter" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    iter: int,
    {%- endif %}
    {%- if "max_counter" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    max_counter: float,
    {%- endif %}
    {%- if "total_unique_indices" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    total_unique_indices: int,
    {%- endif %}
    {%- if "iter" not in args_pt2.unified_pt2.split_unpacked_arg_names %}
    iter: int = 0,
    {%- endif %}
    apply_global_weight_decay: bool = False,
    {%- if "prev_iter_dev" not in args.split_function_arg_names %}
    # only pass prev_iter_dev since prev_iter is never created on UVM
    prev_iter_dev: Optional[torch.Tensor] = None,
    {%- endif %}
    gwd_lower_bound: float = 0.0,
    mixed_D: bool = True,
    {%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
    row_counter: Optional[Momentum] = None,
    {%- endif %}
) -> torch.Tensor:
    """
    Pack the lookup arguments and call the generated PT2 embedding lookup op.

    Scalar and tensor arguments are first collected into per-type dicts
    (`dict_aux_*`, `dict_optim_*`) and then flattened into the `aux_*` and
    `optim_*` lists that are passed to the op. The CPU path is selected when
    `common_args.host_weights` is non-empty.

    Args:
        common_args: Weights, indices/offsets, cache tensors and VBE metadata
            that are forwarded to the op.
        optimizer_args: Optimizer hyperparameters; only the fields used by
            this optimizer are forwarded.
        iter: Forwarded as the `iter` entry of `aux_int`.
        apply_global_weight_decay: Forwarded in `aux_bool`.
        prev_iter_dev: Forwarded as the `prev_iter_dev` aux tensor for
            optimizers without a `prev_iter` state; otherwise `prev_iter.dev`
            is forwarded instead.
        gwd_lower_bound: Forwarded in `aux_float`.
        mixed_D: Forwarded in `aux_bool`.

    Optimizer state arguments (e.g. `momentum1`, `momentum2`, `prev_iter`,
    `row_counter`) are part of the signature only for optimizers that use
    them.

    Returns:
        The tensor returned by the lookup op.
    """
    {%- if is_experimental_optimizer %}
    # By design, the warning only shows up once
    warnings.warn(
        f"""\033[93m
        [FBGEMM_GPU] NOTE: The training optimizer '{{ optimizer }}' is marked as
        EXPERIMENTAL and thus not optimized, in order to reduce code compilation
        times and build sizes!
        \033[0m"""
    )
    {%- endif %}
    # host_weights is only used for CPU training
    use_cpu = common_args.host_weights.numel() > 0
    vbe_metadata = common_args.vbe_metadata

    {%- if has_cpu_support and not has_gpu_support %}
    assert (use_cpu), "{{ optimizer }} has only CPU support. host_weights.numel() must be greater than 0."
    {%- endif %}
    {%- if ssd %}
    ssd_tensors = []
    {%- for tensor in ssd_tensors %}
    assert "{{ tensor }}" in common_args.ssd_tensors, (
        "{{ tensor }} must be in common_args.ssd_tensors. "
        "Please check the backend version"
    )
    ssd_tensors.append(common_args.ssd_tensors["{{ tensor }}"])
    {%- endfor %}
    {%- endif %}

    # pack weights: the non-CPU list holds dev/uvm weights plus lxu_cache_weights
    # (5 items); the CPU list holds host weights only (3 items)
    weights = [
        common_args.dev_weights,
        common_args.uvm_weights,
        common_args.weights_placements,
        common_args.weights_offsets,
        common_args.lxu_cache_weights,
    ] if not use_cpu else [
        common_args.host_weights,
        common_args.weights_placements,
        common_args.weights_offsets,
    ]
    # Per-type staging dicts; values are pulled out of these by key further below
    dict_aux_tensor: Dict[str, Optional[torch.Tensor]] = {
        "B_offsets": vbe_metadata.B_offsets,
        "vbe_output_offsets_feature_rank": vbe_metadata.output_offsets_feature_rank,
        "vbe_B_offsets_rank_per_feature": vbe_metadata.B_offsets_rank_per_feature,
        "lxu_cache_locations": common_args.lxu_cache_locations,
        "uvm_cache_stats": common_args.uvm_cache_stats,
    }

    dict_aux_int: Dict[str, int] = {
        "iter": iter, 
        "info_B_num_bits": common_args.info_B_num_bits, 
        "info_B_mask": common_args.info_B_mask,
    }
    
    dict_aux_float: Dict[str, float] = {
        "gwd_lower_bound": gwd_lower_bound,
    }

    dict_aux_bool: Dict[str, bool] = {
        "is_experimental_tbe": common_args.is_experimental,
        "use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
        "use_homogeneous_placements": common_args.use_homogeneous_placements,
        "apply_global_weight_decay": apply_global_weight_decay,
        "mixed_D": mixed_D,
        {%- if ssd %}
        "enable_optimizer_offloading": common_args.enable_optimizer_offloading,
        {%- endif %}
    }
    dict_optim_int: Dict[str, int] = {}
    dict_optim_float: Dict[str, float] = {}
    dict_optim_bool: Dict[str, bool] = {}

    # Explicitly pass only prev_iter_dev for global weight decay, unless it already exists in optim arg
    {%- if "prev_iter" not in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_aux_tensor["prev_iter_dev"] = prev_iter_dev
    {%- else %}
    dict_aux_tensor["prev_iter_dev"] = prev_iter.dev
    {%- endif %}
    

    # optimizer_args
    {%- if optimizer == "none" %}
    dict_optim_int["total_hash_size"] = optimizer_args.total_hash_size
    {%- endif %} # if optimizer == none
    dict_aux_bool["gradient_clipping"] = optimizer_args.gradient_clipping
    dict_aux_float["max_gradient"] = optimizer_args.max_gradient
    dict_aux_bool["stochastic_rounding"] = optimizer_args.stochastic_rounding
    {%- if "eps" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_float["eps"] = optimizer_args.eps
    {%- endif %}
    {%- if "beta1" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_float["beta1"] = optimizer_args.beta1
    {%- endif %}
    {%- if "beta2" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_float["beta2"] = optimizer_args.beta2
    {%- endif %}
    {%- if "weight_decay" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_float["weight_decay"] = optimizer_args.weight_decay
    {%- endif %}
    {%- if "weight_decay_mode" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_int["weight_decay_mode"] = optimizer_args.weight_decay_mode
    {%- endif %}
    {%- if "eta" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_float["eta"] = optimizer_args.eta
    {%- endif %}
    {%- if "momentum" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_float["momentum"] = optimizer_args.momentum
    {%- endif %}
    {%- if "counter_halflife" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_int["counter_halflife"] = optimizer_args.counter_halflife
    {%- endif %}
    {%- if "adjustment_iter" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_int["adjustment_iter"] = optimizer_args.adjustment_iter
    {%- endif %}
    {%- if "adjustment_ub" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_float["adjustment_ub"] = optimizer_args.adjustment_ub
    {%- endif %}
    {%- if "learning_rate_mode" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_int["learning_rate_mode"] = optimizer_args.learning_rate_mode
    {%- endif %}
    {%- if "grad_sum_decay" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_int["grad_sum_decay"] = optimizer_args.grad_sum_decay
    {%- endif %}
    {%- if "tail_id_threshold" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_float["tail_id_threshold"] = optimizer_args.tail_id_threshold
    {%- endif %}
    {%- if "is_tail_id_thresh_ratio" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_int["is_tail_id_thresh_ratio"] = optimizer_args.is_tail_id_thresh_ratio
    {%- endif %}
    {%- if "weight_norm_coefficient" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_float["weight_norm_coefficient"] = optimizer_args.weight_norm_coefficient
    {%- endif %}
    {%- if "lower_bound" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_float["lower_bound"] = optimizer_args.lower_bound
    {%- endif %}
    {%- if "regularization_mode" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_int["regularization_mode"] = optimizer_args.regularization_mode
    {%- endif %}
    {%- if "max_norm" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_float["max_norm"] = optimizer_args.max_norm
    {%- endif %}
    {%- if "max_counter" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_float["max_counter"] = max_counter
    {%- endif %}
    {%- if "use_rowwise_bias_correction" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    dict_optim_bool["use_rowwise_bias_correction"] = optimizer_args.use_rowwise_bias_correction
    {%- endif %}

    {%- if "momentum1" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    {{ pack_tensors("momentum1") }}
    {%- endif %}
    {%- if "momentum2" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    {{ pack_tensors("momentum2") }}
    {%- endif %}
    {%- if "prev_iter" in args_pt2.unified_pt2.split_unpacked_arg_names %}
    {{ pack_tensors("prev_iter") }}
    {%- endif %}
    {%- if "row_counter" in args_pt2.unified_pt2.split_unpacked_arg_names  and "row_counter" not in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
    {{ pack_tensors("row_counter") }}
    {%- endif %}
    {%- if "row_counter" in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
    
    # Optional row_counter: it must be provided whenever use_rowwise_bias_correction is set.
    # NOTE(review): the row_counter_* locals assigned below are not referenced again in
    # this function - confirm whether they are still needed.
    if optimizer_args.use_rowwise_bias_correction and row_counter is not None:
        row_counter_host = None # not supported on CPU
        row_counter_dev = row_counter.dev
        row_counter_uvm = row_counter.uvm
        row_counter_offsets = row_counter.offsets
        row_counter_placements = row_counter.placements   
    elif optimizer_args.use_rowwise_bias_correction:
        assert False, "`use_rowwise_bias_correction` is set, `row_counter` cannot be None"
    else:
        row_counter_host = None
        row_counter_dev = None
        row_counter_uvm = None
        row_counter_offsets = None
        row_counter_placements = None  
    {%- endif %}

    # Flatten the aux staging dicts into lists, in the order defined by aux_args
    {{ pack_to_list("aux_tensor") }}
    {{ pack_to_list("aux_int") }}
    {{ pack_to_list("aux_float") }}
    {{ pack_to_list("aux_bool") }}

    {%- if "optim_tensor" in args_pt2.unified_pt2.split_function_arg_names %}
    optim_tensor: List[Optional[torch.Tensor]] = []
    # We cannot do list of optional tensorlist (optional tensorlist is Tensor?[]).
    # we need to pack optimizer optional tensors in a flatten manner.
    # We pack unified args (i.e., 5 items) since it's very confusing to pack/unpack per device (i.e, 3 for cpu and 4 for cuda)
    # e.g., if we have optim optional tensors x and y, the optim_tensor will look like
    # [x_host, x_dev, x_uvm, x_placements, x_offsets, y_host, y_dev, y_uvm, y_placements, y_offsets]
    # {{ args_pt2.unified_pt2.split_args_dict["optim_tensor"] }}
    {%- for name in args_pt2.unified_pt2.split_args_dict["optim_tensor"] %}
    {{ pack_optim_optional_tensor(name) }}
    {%- endfor %}
    {%- endif %}

    # optim_int
    {%- if "optim_int" in args_pt2.unified_pt2.split_function_arg_names %}
    optim_int: List[int] = []
    {%- for name in args_pt2.unified_pt2.split_args_dict["optim_int"] %}
    optim_int.append(dict_optim_int["{{ name }}"])
    {%- endfor %}
    {%- endif %}
    # optim_float
    # {{ args_pt2.unified_pt2.split_function_arg_names }}
    {%- if "optim_float" in args_pt2.unified_pt2.split_function_arg_names %}
    optim_float: List[float] = []
    {%- for name in args_pt2.unified_pt2.split_args_dict["optim_float"] %}
    optim_float.append(dict_optim_float["{{ name }}"])
    {%- endfor %}
    {%- endif %}
    # optim_bool
    {%- if "optim_bool" in args_pt2.unified_pt2.split_function_arg_names %}
    optim_bool: List[bool] = []
    {%- for name in args_pt2.unified_pt2.split_args_dict["optim_bool"] %}
    optim_bool.append(dict_optim_bool["{{ name }}"])
    {%- endfor %}
    {%- endif %} 

    return torch.ops.fbgemm.{{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2(
        # common_args
        {%- if not dense %}
        placeholder_autograd_tensor=common_args.placeholder_autograd_tensor,
        {%- endif %}
        # weights
        weights=weights,
        D_offsets=common_args.D_offsets,
        total_D=common_args.total_D,
        max_D=common_args.max_D,
        hash_size_cumsum=common_args.hash_size_cumsum,
        total_hash_size_bits=common_args.total_hash_size_bits,
        indices=common_args.indices,
        offsets=common_args.offsets,
        pooling_mode=common_args.pooling_mode,
        indice_weights=common_args.indice_weights,
        feature_requires_grad=common_args.feature_requires_grad,
        output_dtype=common_args.output_dtype,
        {%- if ssd %}
        ssd_tensors=ssd_tensors,
        {%- endif %}
        # VBE metadata
        max_B=vbe_metadata.max_B,
        max_B_feature_rank=vbe_metadata.max_B_feature_rank,
        vbe_output_size=vbe_metadata.output_size,
        # aux_tensor
        aux_tensor=aux_tensor,
        # aux_int
        aux_int=aux_int,
        # aux_float
        aux_float=aux_float,
        # aux_bool
        aux_bool=aux_bool,
        {%- if "learning_rate_tensor" in args_pt2.unified_pt2.split_unpacked_arg_names %}
        learning_rate_tensor=common_args.learning_rate_tensor,
        {%- endif %}

        # momentum1
        {%- if "momentum1" in args_pt2.unified_pt2.split_unpacked_arg_names %}
        momentum1 = momentum1_list,
        {%- endif %}
        # momentum2
        {%- if "momentum2" in args_pt2.unified_pt2.split_unpacked_arg_names %}
        momentum2=momentum2_list,
        {%- endif %}
        # prev_iter
        {%- if "prev_iter" in args_pt2.unified_pt2.split_unpacked_arg_names %}
        prev_iter=prev_iter_list,
        {%- endif %}
        # row_counter
        {%- if "row_counter" in args_pt2.unified_pt2.split_unpacked_arg_names and "row_counter" not in args_pt2.unified_pt2.split_saved_tensorlist_optional %}
        row_counter=row_counter_list,
        {%- endif %}
        # optim_tensor
        {%- if "optim_tensor" in args_pt2.unified_pt2.split_function_arg_names %}
        optim_tensor=optim_tensor,
        {%- endif %}
        # optim_int
        {%- if "optim_int" in args_pt2.unified_pt2.split_function_arg_names %}
        optim_int=optim_int,
        {%- endif %}
        # optim_float
        {%- if "optim_float" in args_pt2.unified_pt2.split_function_arg_names %}
        optim_float=optim_float,
        {%- endif %}
        # optim_bool
        {%- if "optim_bool" in args_pt2.unified_pt2.split_function_arg_names %}
        optim_bool=optim_bool,
        {%- endif %}
        # optim symint args
        # total_unique_indices
        # NOTE(review): this is gated on `args.split_function_arg_names`, while the matching
        # parameter in the signature is gated on `args_pt2.unified_pt2.split_unpacked_arg_names`
        # - confirm the two conditions always agree.
        {%- if "total_unique_indices" in args.split_function_arg_names %}
        total_unique_indices=total_unique_indices,
        {%- endif %}
    )
